# coding: utf-8
import pandas as pd
import numpy as np
from copy import deepcopy, copy
from joblib import load, dump
from src.Logistic_Bandits import logistic_bandits
from src.Linear_Bandits import linear_bandits
from src.utils import benchmark_sample_action, benchmark_cal_reward_costs
import os

####################################################################################
####################################################################################
############################# AREA OF INPUT PARAMETERS #############################
####################################################################################
####################################################################################
##### Environment parameters
# Repository root: the directory that contains this project's README.md.
PATH = "."
PATH_DATA = PATH + "/data"      # directory holding the input data
PATH_MODELS = PATH + "/models"  # directory holding models and simulation outputs


##### Bandit parameters
# Horizon T; keep 50000 to reproduce the results.
size_norm = 50_000
# Budget constraint; to reproduce the results, run with 1600 and with 2200.
budget = 1600
# Random seed; to reproduce the results, run 10 times with seeds 1989 through 1998.
random_seed = 1989

####################################################################################
####################################################################################
########### Load the Data, Models and Hyper-parameters; Create Output Paths #########
####################################################################################
####################################################################################
### Load the environment data and the conversion model
dt_raw = pd.read_parquet(f"{PATH_DATA}/dt_env.parq")
model_conversion = load(f"{PATH_MODELS}/conversion_model.pkl")

# Budget-specific artifacts: the optimal static policy and the bandit hyper-parameters.
policy_optim = load(f"{PATH_MODELS}/budget_{budget}/policy_optim_static.pkl")

dict_hyper = load(f"{PATH_MODELS}/budget_{budget}/dict_hyper.pkl")
print(dict_hyper)

# Pull out the hyper-parameters consumed by the bandit runs further down.
lmd_logistic, lmd_linear, Z_linear, dict_eta_OCO = (
    dict_hyper[key]
    for key in ("lmd_logistic", "lmd_linear", "Z_linear", "eta_OCO")
)

##### Create the folder for the output model
# One call creates budget_{budget}/random_seed_{random_seed}, including any missing
# parent folders. With exist_ok=True it is a no-op when the folder already exists.
# This replaces the isdir-then-makedirs checks, which can raise FileExistsError if
# another run (e.g. a different seed started in parallel) creates the folder
# between the check and the call. A non-directory at that path still raises.
os.makedirs(f"{PATH_MODELS}/budget_{budget}/random_seed_{random_seed}", exist_ok=True)

####################################################################################
####################################################################################
############################# Prepare the data #####################################
####################################################################################
####################################################################################
##### Variables shared by the logistic and linear bandits
# Raw context columns that define a customer context.
var_model = [
    "RISK_SCORE",
    "EDUCATION",
    "MARRIAGE",
    "AMOUNT_CLUSTER",
    "AGE_CLUSTER",
]
# One-hot feature names exposed by the conversion model (`feature_names_in_`).
var_model_onehot = [*model_conversion.feature_names_in_]
# Base reward/cost feature columns for each bandit family.
var_base_reward_costs_logistic = ["constant"]
var_base_reward_costs_linear = [
    "amount_norm",
    "discount_base_norm",
    "discount",
    "constant",
]
# Action -1 first (presumably a special/null arm — confirm), then the other actions.
list_actions = [-1, 10, 20, 35, 55, 80]

##### Prepare the data
# Helper columns referenced by the base reward/cost feature lists.
dt_raw = dt_raw.assign(discount=0, constant=1)

# Draw T rows with replacement. DataFrame.sample without random_state uses
# NumPy's global RNG, which is seeded here so the draw is reproducible.
np.random.seed(random_seed)
dt_env = dt_raw.sample(size_norm, replace=True)
dt_env = dt_env.reset_index(drop=True)

# Number every distinct context 1..K and attach that id to each sampled row.
context_list = (
    dt_env[var_model]
    .drop_duplicates()
    .reset_index(drop=True)
)
context_list = context_list.assign(
    index_context_approx=context_list.index.values + 1
)
dt_env = (
    dt_env
    .merge(context_list, how="left", on=var_model)
    .reset_index(drop=True)
)

####################################################################################
####################################################################################
####### Simulate the Optimal Static Policy and the Linear and Logistic Bandits ######
####################################################################################
####################################################################################

##### Simulate the Optimal Static Policy

# Look up the optimal static policy row for every sampled context.
# validate="many_to_one" makes pandas raise MergeError if policy_optim has
# duplicated index_context values. Otherwise those duplicates would silently
# multiply rows and misalign the policy matrix with dt_env.
policy_optim_env = (
    dt_env[["index_context"]]
    .merge(policy_optim, how="left", on="index_context", validate="many_to_one")
    .reset_index(drop=True)
)
# Matrix with one column per action, in list_actions order.
policy_optim_env = policy_optim_env[list_actions].values

# Column position -> action value.
dict_mapping_actions = dict(enumerate(list_actions))
# DataFrame.copy() is a deep copy (it is what copy.deepcopy delegates to for a
# DataFrame), so the column added below does not touch dt_env.
dt_env_with_policy_optim = dt_env.copy()
dt_env_with_policy_optim["optim_static_action"] = benchmark_sample_action(
    policy_optim_env, dict_mapping_actions, seed=random_seed
)
print(dt_env_with_policy_optim["optim_static_action"].value_counts(normalize=True))

dt_env_with_policy_optim = benchmark_cal_reward_costs(
    dt_env_with_policy_optim, model_conversion, budget, random_seed, apply_null=True
)
# NOTE(review): the max of the running sum equals the column total only if every
# entry is non-negative. Confirm that reward/cost values cannot be negative.
print(f"Realized Reward: {np.cumsum(dt_env_with_policy_optim['reward']).max()}")
print(f"Realized Cost2: {np.cumsum(dt_env_with_policy_optim['cost2']).max()}")
print(f"Realized Cost1: {np.cumsum(dt_env_with_policy_optim['cost1']).max()}")

##### Export the simulation results of the Optimal Static Policy
# Keep identifiers, the sampled action, its predicted/realized conversion and the
# realized reward and costs, then persist them under the seed's output folder.
var_keep_optim = [
    "ID",
    "index_context",
    "index_context_approx",
    "optim_static_action",
    "pred_conv_optim_static_action",
    "conversion_optim_static_action",
    "reward",
    "cost2",
    "cost1",
]

dt_env_with_policy_optim = dt_env_with_policy_optim[var_keep_optim]

dump(
    dt_env_with_policy_optim,
    f"{PATH_MODELS}/budget_{budget}/random_seed_{random_seed}/policy_optim_static.pkl",
)

##### Simulate the Linear and Logistic Bandits and export the simulation results

# Settings shared by both bandit families, defined once so the two runs stay aligned.
norm_costs1 = 7
n_random_action = 50
path_output_ = f"{PATH_MODELS}/budget_{budget}/random_seed_{random_seed}"

for UCB_multiply_ in [0.025, 0.1, 0.3]:
    # dict_eta_OCO is keyed by the string form of the UCB multiplier.
    eta_oco_ = dict_eta_OCO[str(UCB_multiply_)]
    # File-name tag: the multiplier with its decimal point removed
    # (0.025 -> "0025", 0.1 -> "01", 0.3 -> "03").
    # For values written as "0.xxx" this matches the previous `"0" + str(x)[2:]`.
    # Unlike that form, it does not map distinct values >= 1 onto the same tag
    # (e.g. 1.5 and 0.5 both used to become "05" and overwrite each other's files).
    UCB_multiply_str_ = str(UCB_multiply_).replace(".", "")

    print(UCB_multiply_str_)
    print(eta_oco_)

    # NOTE(review): both bandits receive the same dt_env object. If either class
    # modifies it in place, the other sees the change; confirm, or pass a copy.
    dict_logistic_bandits_ = {
        "var_rate": "interest_rate",
        "var_context": var_model,
        "var_model_onehot": var_model_onehot,
        "var_base_reward_costs": var_base_reward_costs_logistic,
        "list_actions": list_actions,
        "seed": random_seed,
        "model_conversion": model_conversion,
        "norm_costs1": norm_costs1,
        "T": size_norm,
        "budget": budget,
        "lmd": lmd_logistic,
        "verbose": True,
        "UCB_multiply": UCB_multiply_,
        "n_random_action": n_random_action,
    }

    obj_logistic_bandits_ = logistic_bandits(dict_logistic_bandits_, dt_env)
    obj_logistic_bandits_.run_simulation()

    dump(obj_logistic_bandits_, f"{path_output_}/logistic_bandits_C{UCB_multiply_str_}.pkl")

    dict_linear_bandits_ = {
        "var_rate": "interest_rate",
        "var_context": var_model,
        "var_model_onehot": var_model_onehot,
        "var_base_reward_costs": var_base_reward_costs_linear,
        "eta_oco": eta_oco_,
        "seed": random_seed,
        "list_actions": list_actions,
        "model_conversion": model_conversion,
        "Z": Z_linear,
        "norm_costs1": norm_costs1,
        "T": size_norm,
        "budget": budget,
        "lmd": lmd_linear,
        "verbose": True,
        "UCB_multiply": UCB_multiply_,
        "n_random_action": n_random_action,
    }

    obj_linear_bandits_ = linear_bandits(dict_linear_bandits_, dt_env)
    obj_linear_bandits_.run_simulation()
    dump(obj_linear_bandits_, f"{path_output_}/linear_bandits_C{UCB_multiply_str_}.pkl")

    print("------------------------------")

